{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "composite-transparency",
   "metadata": {},
   "source": [
    "# Representation Model Tutorial"
   ]
  },
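  {
   "cell_type": "markdown",
   "id": "biblical-camcorder",
   "metadata": {},
   "source": [
    "A representation model such as TS2Vec is a kind of self-supervised model whose goal is to learn a general-purpose feature representation that can be reused by downstream tasks. Mainstream self-supervised learning methods are currently either generative or contrastive; the TS2Vec model used in this example is a self-supervised model based on contrastive learning."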
   ]
  },
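  {
   "cell_type": "markdown",
   "id": "contrastive-sketch",
   "metadata": {},
   "source": [
    "To make the idea of contrastive learning concrete, the following is a minimal NumPy sketch of an InfoNCE-style loss. It is a simplified illustration and not the loss implemented by TS2Vec or PaddleTS: the function name `info_nce_loss`, the temperature value and the random toy embeddings are all assumptions made for this example. Row `i` of the two inputs is treated as a positive pair, and every other row in the batch serves as a negative.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def info_nce_loss(view_a, view_b, temperature=0.1):\n",
    "    # view_a, view_b: (batch, dims) embeddings of two views of the same samples (hypothetical inputs)\n",
    "    a = view_a / np.linalg.norm(view_a, axis=1, keepdims=True)\n",
    "    b = view_b / np.linalg.norm(view_b, axis=1, keepdims=True)\n",
    "    logits = a @ b.T / temperature  # cosine similarity of every pair, scaled by the temperature\n",
    "    logits = logits - logits.max(axis=1, keepdims=True)  # subtract the row maximum for numerical stability\n",
    "    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))\n",
    "    return -np.mean(np.diag(log_prob))  # cross-entropy with the matching row as the target\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "anchor = rng.normal(size=(8, 16))\n",
    "positive = anchor + 0.01 * rng.normal(size=(8, 16))  # slightly perturbed copy of each row\n",
    "mismatched = np.roll(positive, 1, axis=0)  # shift the rows so that the pairs no longer match\n",
    "\n",
    "loss_matched = info_nce_loss(anchor, positive)\n",
    "loss_mismatched = info_nce_loss(anchor, mismatched)\n",
    "print(loss_matched < loss_mismatched)\n",
    "```\n",
    "\n",
    "Correctly matched pairs yield a lower loss than mismatched ones; minimizing such a loss is what pulls the representations of related views together and pushes unrelated ones apart."
   ]
  },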
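  {
   "cell_type": "markdown",
   "id": "lasting-collins",
   "metadata": {},
   "source": [
    "A self-supervised model is generally used in two stages:\n",
    "1. Pre-train on unlabeled data, independently of any downstream task\n",
    "2. Fine-tune on the downstream task with labeled data"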
   ]
  },
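  {
   "cell_type": "markdown",
   "id": "robust-isaac",
   "metadata": {},
   "source": [
    "Using TS2Vec with a downstream task follows the same self-supervised paradigm and also consists of two stages:\n",
    "1. Train the representation model\n",
    "2. Feed the output of the representation model to the downstream task (forecasting, in this example)"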
   ]
  },
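  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To serve both beginners and more experienced developers, this tutorial presents two ways of using the representation model:\n",
    "1. A pipeline that combines the representation model with the downstream task, which is the easiest option for beginners\n",
    "2. A decoupled workflow that shows step by step how the representation model and the downstream task are combined"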
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Method 1: a pipeline that combines the representation model with the downstream task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "                            OT    HUFL   HULL    MUFL   MULL   LUFL   LULL\n",
       "date                                                                      \n",
       "2016-07-01 00:00:00  30.531000   5.827  2.009   1.599  0.462  4.203  1.340\n",
       "2016-07-01 01:00:00  27.787001   5.693  2.076   1.492  0.426  4.142  1.371\n",
       "2016-07-01 02:00:00  27.787001   5.157  1.741   1.279  0.355  3.777  1.218\n",
       "2016-07-01 03:00:00  25.044001   5.090  1.942   1.279  0.391  3.807  1.279\n",
       "2016-07-01 04:00:00  21.948000   5.358  1.942   1.492  0.462  3.868  1.279\n",
       "...                        ...     ...    ...     ...    ...    ...    ...\n",
       "2016-09-21 01:00:00  21.878000  13.396  4.354  11.940  3.198  1.310  0.670\n",
       "2016-09-21 02:00:00  22.230000  12.458  4.354  11.407  2.878  1.127  0.579\n",
       "2016-09-21 03:00:00  22.230000  12.927  4.086  11.655  2.878  1.127  0.579\n",
       "2016-09-21 04:00:00  22.722000  12.324  6.162  11.407  4.655  1.340  0.609\n",
       "2016-09-21 05:00:00  22.511000  14.133  6.497  12.650  5.082  1.614  0.670\n",
       "\n",
       "[1974 rows x 7 columns]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import paddle\n",
    "\n",
    "from paddlets.models.representation.dl.ts2vec import TS2Vec\n",
    "from paddlets.datasets.repository import get_dataset\n",
    "from paddlets.models.representation.task.repr_forecasting import ReprForecasting\n",
    "\n",
    "# Fix the random seeds so that the run is reproducible\n",
    "np.random.seed(2022)\n",
    "paddle.seed(2022)\n",
    "\n",
    "# Load ETTh1 and keep only the records up to 2016-09-22 06:00:00\n",
    "data = get_dataset('ETTh1')\n",
    "data, _ = data.split('2016-09-22 06:00:00')\n",
    "# Records up to 2016-09-21 05:00:00 are used for training; the rest is held out for testing\n",
    "train_data, test_data = data.split('2016-09-21 05:00:00')\n",
    "train_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "                       HUFL   HULL   LUFL   LULL    MUFL   MULL\n",
       "date                                                           \n",
       "2016-07-01 00:00:00   5.827  2.009  4.203  1.340   1.599  0.462\n",
       "2016-07-01 01:00:00   5.693  2.076  4.142  1.371   1.492  0.426\n",
       "2016-07-01 02:00:00   5.157  1.741  3.777  1.218   1.279  0.355\n",
       "2016-07-01 03:00:00   5.090  1.942  3.807  1.279   1.279  0.391\n",
       "2016-07-01 04:00:00   5.358  1.942  3.868  1.279   1.492  0.462\n",
       "...                     ...    ...    ...    ...     ...    ...\n",
       "2016-09-21 01:00:00  13.396  4.354  1.310  0.670  11.940  3.198\n",
       "2016-09-21 02:00:00  12.458  4.354  1.127  0.579  11.407  2.878\n",
       "2016-09-21 03:00:00  12.927  4.086  1.127  0.579  11.655  2.878\n",
       "2016-09-21 04:00:00  12.324  6.162  1.340  0.609  11.407  4.655\n",
       "2016-09-21 05:00:00  14.133  6.497  1.614  0.670  12.650  5.082\n",
       "\n",
       "[1974 rows x 6 columns]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data.get_observed_cov()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2022-11-02 18:49:52,748] [paddlets.models.representation.task.repr_forecasting] [INFO] Repr model fit start\n",
      "[2022-11-02 18:49:59,383] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 000| loss: 1414.767578| 0:00:01s\n",
      "[2022-11-02 18:49:59,575] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 001| loss: 803.826904| 0:00:01s\n",
      "[2022-11-02 18:49:59,738] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 002| loss: 471.702454| 0:00:01s\n",
      "[2022-11-02 18:49:59,895] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 003| loss: 260.660858| 0:00:01s\n",
      "[2022-11-02 18:50:00,052] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 004| loss: 189.227371| 0:00:02s\n",
      "[2022-11-02 18:50:00,223] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 005| loss: 145.273895| 0:00:02s\n",
      "[2022-11-02 18:50:00,398] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 006| loss: 111.410248| 0:00:02s\n",
      "[2022-11-02 18:50:00,569] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 007| loss: 80.815414| 0:00:02s\n",
      "[2022-11-02 18:50:00,749] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 008| loss: 68.463318| 0:00:02s\n",
      "[2022-11-02 18:50:00,926] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 009| loss: 51.364281| 0:00:02s\n",
      "[2022-11-02 18:50:01,109] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 010| loss: 44.166645| 0:00:03s\n",
      "[2022-11-02 18:50:01,282] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 011| loss: 43.780029| 0:00:03s\n",
      "[2022-11-02 18:50:01,453] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 012| loss: 41.319237| 0:00:03s\n",
      "[2022-11-02 18:50:01,616] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 013| loss: 35.690578| 0:00:03s\n",
      "[2022-11-02 18:50:01,794] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 014| loss: 30.054024| 0:00:03s\n",
      "[2022-11-02 18:50:01,948] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 015| loss: 28.020060| 0:00:04s\n",
      "[2022-11-02 18:50:02,113] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 016| loss: 27.151669| 0:00:04s\n",
      "[2022-11-02 18:50:02,276] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 017| loss: 22.419031| 0:00:04s\n",
      "[2022-11-02 18:50:02,435] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 018| loss: 21.138212| 0:00:04s\n",
      "[2022-11-02 18:50:02,601] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 019| loss: 20.676611| 0:00:04s\n",
      "[2022-11-02 18:50:02,603] [paddlets.models.representation.task.repr_forecasting] [INFO] Repr model fit end\n",
      "[2022-11-02 18:50:04,736] [paddlets.models.representation.task.repr_forecasting] [INFO] Downstream model fit start\n",
      "[2022-11-02 18:50:05,003] [paddlets.models.representation.task.repr_forecasting] [INFO] Downstream model fit end\n"
     ]
    }
   ],
   "source": [
    "ts2vec_params = {\n",
    "    \"segment_size\": 200,\n",
    "    \"repr_dims\": 320,\n",
    "    \"batch_size\": 32,\n",
    "    \"sampling_stride\": 200,\n",
    "    \"max_epochs\": 20,\n",
    "}\n",
    "model = ReprForecasting(\n",
    "    in_chunk_len=200,\n",
    "    out_chunk_len=24,\n",
    "    sampling_stride=1,\n",
    "    repr_model=TS2Vec,\n",
    "    repr_model_params=ts2vec_params,\n",
    ")\n",
    "model.fit(train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "                            OT\n",
       "2016-09-21 06:00:00  22.991394\n",
       "2016-09-21 07:00:00  23.517416\n",
       "2016-09-21 08:00:00  24.104179\n",
       "2016-09-21 09:00:00  24.243126\n",
       "2016-09-21 10:00:00  24.298140\n",
       "2016-09-21 11:00:00  24.367647\n",
       "2016-09-21 12:00:00  24.640148\n",
       "2016-09-21 13:00:00  24.640419\n",
       "2016-09-21 14:00:00  24.576252\n",
       "2016-09-21 15:00:00  24.295414\n",
       "2016-09-21 16:00:00  23.657946\n",
       "2016-09-21 17:00:00  23.767244\n",
       "2016-09-21 18:00:00  23.669827\n",
       "2016-09-21 19:00:00  23.264309\n",
       "2016-09-21 20:00:00  22.973028\n",
       "2016-09-21 21:00:00  22.930428\n",
       "2016-09-21 22:00:00  22.845171\n",
       "2016-09-21 23:00:00  22.806917\n",
       "2016-09-22 00:00:00  22.769144\n",
       "2016-09-22 01:00:00  23.296446\n",
       "2016-09-22 02:00:00  23.689632\n",
       "2016-09-22 03:00:00  24.013086\n",
       "2016-09-22 04:00:00  23.938864\n",
       "2016-09-22 05:00:00  23.524876"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict(train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The built-in `backtest` API can be used for prediction and evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Backtest Progress: 100%|██████████| 2/2 [00:00<00:00,  3.68it/s]\n"
     ]
    }
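   ],
   "source": [
    "from paddlets.utils.backtest import backtest\n",
    "\n",
    "# Backtest from `start` with a 24-step prediction window and a 24-step stride,\n",
    "# returning the predictions together with the score\n",
    "score, predicts = backtest(\n",
    "    data,\n",
    "    model,\n",
    "    start=\"2016-09-21 06:00:00\",\n",
    "    predict_window=24,\n",
    "    stride=24,\n",
    "    return_predicts=True,\n",
    ")"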
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Method 2: decoupling the representation model from the downstream regression task"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "painted-gauge",
   "metadata": {},
   "source": [
    "# Stage 1\n",
    "1. Train the representation model\n",
    "\n",
    "2. Output the representations of the training and test sets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "facial-replacement",
   "metadata": {},
   "source": [
    "# Prepare the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "practical-minimum",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "                            OT    HUFL   HULL    MUFL   MULL   LUFL   LULL\n",
       "date                                                                      \n",
       "2016-07-01 00:00:00  30.531000   5.827  2.009   1.599  0.462  4.203  1.340\n",
       "2016-07-01 01:00:00  27.787001   5.693  2.076   1.492  0.426  4.142  1.371\n",
       "2016-07-01 02:00:00  27.787001   5.157  1.741   1.279  0.355  3.777  1.218\n",
       "2016-07-01 03:00:00  25.044001   5.090  1.942   1.279  0.391  3.807  1.279\n",
       "2016-07-01 04:00:00  21.948000   5.358  1.942   1.492  0.462  3.868  1.279\n",
       "...                        ...     ...    ...     ...    ...    ...    ...\n",
       "2016-09-21 01:00:00  21.878000  13.396  4.354  11.940  3.198  1.310  0.670\n",
       "2016-09-21 02:00:00  22.230000  12.458  4.354  11.407  2.878  1.127  0.579\n",
       "2016-09-21 03:00:00  22.230000  12.927  4.086  11.655  2.878  1.127  0.579\n",
       "2016-09-21 04:00:00  22.722000  12.324  6.162  11.407  4.655  1.340  0.609\n",
       "2016-09-21 05:00:00  22.511000  14.133  6.497  12.650  5.082  1.614  0.670\n",
       "\n",
       "[1974 rows x 7 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from paddlets.models.representation.dl.ts2vec import TS2Vec\n",
    "from paddlets.datasets.repository import get_dataset\n",
    "\n",
    "data = get_dataset('ETTh1')\n",
    "data, _ = data.split('2016-09-22 06:00:00')\n",
    "train_data, test_data = data.split('2016-09-21 05:00:00')\n",
    "train_data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bearing-title",
   "metadata": {},
   "source": [
    "# Train the representation model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2022-11-02 11:50:20,539] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 000| loss: 78.309945| 0:00:10s\n",
      "[2022-11-02 11:50:30,399] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 001| loss: 4.548435| 0:00:19s\n",
      "[2022-11-02 11:50:40,499] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 002| loss: 3.869072| 0:00:30s\n",
      "[2022-11-02 11:50:50,391] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 003| loss: 3.646648| 0:00:39s\n",
      "[2022-11-02 11:51:00,268] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 004| loss: 3.551296| 0:00:49s\n",
      "[2022-11-02 11:51:10,484] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 005| loss: 3.591498| 0:01:00s\n",
      "[2022-11-02 11:51:20,490] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 006| loss: 3.517901| 0:01:10s\n",
      "[2022-11-02 11:51:30,249] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 007| loss: 3.508513| 0:01:19s\n",
      "[2022-11-02 11:51:39,716] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 008| loss: 3.427224| 0:01:29s\n",
      "[2022-11-02 11:51:49,544] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 009| loss: 3.543946| 0:01:39s\n",
      "[2022-11-02 11:51:59,348] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 010| loss: 3.488514| 0:01:48s\n",
      "[2022-11-02 11:52:08,890] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 011| loss: 3.485540| 0:01:58s\n",
      "[2022-11-02 11:52:18,570] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 012| loss: 3.476332| 0:02:08s\n",
      "[2022-11-02 11:52:28,162] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 013| loss: 3.451701| 0:02:17s\n",
      "[2022-11-02 11:52:38,172] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 014| loss: 3.503496| 0:02:27s\n",
      "[2022-11-02 11:52:48,052] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 015| loss: 3.497426| 0:02:37s\n",
      "[2022-11-02 11:52:58,101] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 016| loss: 3.467514| 0:02:47s\n",
      "[2022-11-02 11:53:07,933] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 017| loss: 3.470402| 0:02:57s\n",
      "[2022-11-02 11:53:17,619] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 018| loss: 3.435449| 0:03:07s\n",
      "[2022-11-02 11:53:27,177] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 019| loss: 3.438016| 0:03:16s\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the TS2Vec model\n",
    "ts2vec = TS2Vec(\n",
    "    segment_size=200,  # maximum sequence length\n",
    "    repr_dims=320,  # dimensionality of the output representation\n",
    "    batch_size=32,\n",
    "    max_epochs=20,\n",
    ")\n",
    "# Train\n",
    "ts2vec.fit(train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "knowing-department",
   "metadata": {},
   "source": [
    "# Output the representations of the training and test sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fluid-favor",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1999/1999 [00:02<00:00, 958.44it/s]\n"
     ]
    }
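   ],
   "source": [
    "sliding_len = 200\n",
    "# Encode the full series (training + test) in one pass\n",
    "all_reprs = ts2vec.encode(data, sliding_len=sliding_len)\n",
    "# Split the representations along the time axis at the end of the training set\n",
    "split_tag = len(train_data['OT'])\n",
    "train_reprs = all_reprs[:, :split_tag]\n",
    "test_reprs = all_reprs[:, split_tag:]"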
   ]
  },
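  {
   "cell_type": "markdown",
   "id": "sliding-encode-sketch",
   "metadata": {},
   "source": [
    "The encode-then-split step can be illustrated with a toy sketch. It assumes that sliding-window encoding computes the representation of each time step from a trailing window of at most `sliding_len` points; this is an assumption about the behaviour, not a description of the PaddleTS implementation. The helpers `toy_sliding_encode` and `toy_encoder` are hypothetical stand-ins for the trained TS2Vec encoder.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def toy_encoder(window):\n",
    "    # Hypothetical encoder: per-column mean and max of the window\n",
    "    return np.concatenate([window.mean(axis=0), window.max(axis=0)])\n",
    "\n",
    "def toy_sliding_encode(series, sliding_len, encoder):\n",
    "    # series: (time, columns); returns an array of shape (1, time, repr_dims)\n",
    "    reprs = []\n",
    "    for t in range(series.shape[0]):\n",
    "        window = series[max(0, t - sliding_len + 1): t + 1]  # only current and past points\n",
    "        reprs.append(encoder(window))\n",
    "    return np.stack(reprs)[np.newaxis]\n",
    "\n",
    "series = np.arange(20, dtype=float).reshape(10, 2)  # 10 time steps, 2 columns\n",
    "toy_reprs = toy_sliding_encode(series, sliding_len=3, encoder=toy_encoder)\n",
    "\n",
    "toy_split = 7  # pretend the first 7 time steps are the training set\n",
    "toy_train_reprs = toy_reprs[:, :toy_split]\n",
    "toy_test_reprs = toy_reprs[:, toy_split:]\n",
    "print(toy_reprs.shape, toy_train_reprs.shape, toy_test_reprs.shape)\n",
    "```\n",
    "\n",
    "Because every time step gets exactly one representation vector, slicing along the second axis at the length of the training set separates the training and test representations."
   ]
  },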
  {
   "cell_type": "markdown",
   "id": "frequent-pavilion",
   "metadata": {},
   "source": [
    "# Stage 2\n",
    "1. Build the training and test samples for the regression model\n",
    "\n",
    "2. Train and predict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "magnetic-technology",
   "metadata": {},
   "source": [
    "# Build the training and test samples for the regression model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "hybrid-apple",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_pred_samples(features, data, pred_len, drop=0):\n",
    "    # Pair the representation at time t with the raw values at t+1 ... t+pred_len\n",
    "    n = data.shape[1]\n",
    "    features = features[:, :-pred_len]\n",
    "    labels = np.stack([data[:, i:1 + n + i - pred_len] for i in range(pred_len)], axis=2)[:, 1:]\n",
    "    # Skip the first `drop` time steps\n",
    "    features = features[:, drop:]\n",
    "    labels = labels[:, drop:]\n",
    "    return (features.reshape(-1, features.shape[-1]),\n",
    "            labels.reshape(-1, labels.shape[2] * labels.shape[3]))\n",
    "\n",
    "pred_len = 24  # number of future time steps to predict\n",
    "\n",
    "# Build the training samples\n",
    "train_to_numpy = train_data.to_numpy()\n",
    "train_to_numpy = np.expand_dims(train_to_numpy, 0)  # add a leading axis to match the shape of the encode output\n",
    "train_features, train_labels = generate_pred_samples(train_reprs, train_to_numpy, pred_len, drop=sliding_len)\n",
    "\n",
    "# Build the test samples\n",
    "test_to_numpy = test_data.to_numpy()\n",
    "test_to_numpy = np.expand_dims(test_to_numpy, 0)  # same as above\n",
    "test_features, test_labels = generate_pred_samples(test_reprs, test_to_numpy, pred_len)"
   ]
  },
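  {
   "cell_type": "markdown",
   "id": "pred-samples-toy",
   "metadata": {},
   "source": [
    "The indexing inside `generate_pred_samples` is easier to follow on a small input. The sketch below restates the function so that it runs on its own and applies it to toy arrays; the toy shapes and values are assumptions chosen for illustration and are unrelated to ETTh1.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def generate_pred_samples(features, data, pred_len, drop=0):\n",
    "    n = data.shape[1]\n",
    "    features = features[:, :-pred_len]\n",
    "    labels = np.stack(\n",
    "        [data[:, i:1 + n + i - pred_len] for i in range(pred_len)], axis=2\n",
    "    )[:, 1:]\n",
    "    features = features[:, drop:]\n",
    "    labels = labels[:, drop:]\n",
    "    return (features.reshape(-1, features.shape[-1]),\n",
    "            labels.reshape(-1, labels.shape[2] * labels.shape[3]))\n",
    "\n",
    "# Toy inputs: 1 series, 30 time steps, 4-dim representations, 2 raw columns\n",
    "toy_features = np.zeros((1, 30, 4))\n",
    "toy_data = np.arange(60, dtype=float).reshape(1, 30, 2)  # the row at time t is [2t, 2t+1]\n",
    "\n",
    "toy_X, toy_y = generate_pred_samples(toy_features, toy_data, pred_len=5)\n",
    "print(toy_X.shape, toy_y.shape)\n",
    "print(toy_y[0])  # raw rows at times 1..5, flattened into one label vector\n",
    "```\n",
    "\n",
    "Each sample pairs the representation at time `t` with the flattened raw values of the following `pred_len` steps, so 30 time steps with `pred_len=5` yield 25 samples whose labels have `5 * 2 = 10` entries."
   ]
  },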
  {
   "cell_type": "markdown",
   "id": "supposed-island",
   "metadata": {},
   "source": [
    "# Train and predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "unlike-excuse",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train a ridge regression on the representations\n",
    "from sklearn.linear_model import Ridge\n",
    "lr = Ridge(alpha=0.1)\n",
    "lr.fit(train_features, train_labels)\n",
    "\n",
    "# Predict\n",
    "test_pred = lr.predict(test_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[23.411926  , 14.156788  ,  5.5864105 ,  2.0591881 ,  0.8507492 ,\n",
       "        12.033666  ,  3.8670216 , 24.307272  , 12.970997  ,  5.047309  ,\n",
       "         1.5718118 ,  0.83765304, 11.449925  ,  3.6241622 , 25.072916  ,\n",
       "        11.619985  ,  4.647399  ,  1.2839303 ,  0.8075943 , 10.430354  ,\n",
       "         3.267183  , 25.813192  , 10.340514  ,  4.2520375 ,  1.0931191 ,\n",
       "         0.77652   ,  9.312     ,  3.038098  , 25.522808  , 10.041627  ,\n",
       "         4.0208244 ,  1.4567091 ,  0.779971  ,  8.682292  ,  2.5960407 ,\n",
       "        25.318005  ,  9.248054  ,  4.045926  ,  1.7158421 ,  0.82549214,\n",
       "         7.7777796 ,  2.616672  , 25.033335  ,  8.437286  ,  3.899671  ,\n",
       "         1.6119858 ,  0.821985  ,  7.0049305 ,  2.5203104 , 24.626888  ,\n",
       "         8.839753  ,  4.105016  ,  1.7098855 ,  0.82516253,  7.3333654 ,\n",
       "         2.8516943 , 24.77245   ,  9.711424  ,  4.28503   ,  2.0725615 ,\n",
       "         0.9125666 ,  7.6130066 ,  2.8810837 , 24.521465  , 10.09363   ,\n",
       "         4.5307975 ,  2.1585398 ,  0.936623  ,  7.933614  ,  3.055725  ,\n",
       "        23.676104  , 10.452117  ,  4.783902  ,  2.1145637 ,  0.92233217,\n",
       "         8.417807  ,  3.3903277 , 23.298264  ,  9.692848  ,  4.5316744 ,\n",
       "         1.9641473 ,  0.9540007 ,  7.7680893 ,  3.0913033 , 22.426748  ,\n",
       "         9.253598  ,  4.7156897 ,  1.7622936 ,  0.92439735,  7.504118  ,\n",
       "         3.3047962 , 21.57502   ,  9.340029  ,  4.4564953 ,  1.7061318 ,\n",
       "         0.87354803,  7.5569506 ,  3.2197056 , 20.730991  ,  9.669415  ,\n",
       "         4.2411222 ,  1.8282549 ,  0.8589417 ,  7.7603645 ,  3.001691  ,\n",
       "        20.300604  , 10.24262   ,  4.48209   ,  1.9130459 ,  0.87324536,\n",
       "         8.176726  ,  3.0741835 , 20.033218  , 10.431677  ,  4.2994823 ,\n",
       "         2.0366917 ,  0.9470297 ,  8.132969  ,  2.8900185 , 20.469889  ,\n",
       "        10.352656  ,  4.8578496 ,  2.0966773 ,  1.0325212 ,  8.158232  ,\n",
       "         3.1899695 , 21.493134  , 10.50919   ,  5.260278  ,  2.2267072 ,\n",
       "         1.1208985 ,  7.9971666 ,  3.5411735 , 22.34074   , 10.068594  ,\n",
       "         5.1430774 ,  2.2474127 ,  1.202215  ,  7.586275  ,  3.5101178 ,\n",
       "        23.019749  ,  9.6509495 ,  5.123786  ,  2.4760475 ,  1.2286806 ,\n",
       "         7.056637  ,  3.5256596 , 23.579021  , 10.561832  ,  5.018241  ,\n",
       "         3.0211308 ,  1.307111  ,  7.119049  ,  3.3382998 , 23.042793  ,\n",
       "        11.242623  ,  4.8256283 ,  3.7437994 ,  1.3918667 ,  7.1328316 ,\n",
       "         2.9010665 , 23.111992  , 11.037678  ,  4.459323  ,  3.889421  ,\n",
       "         1.3958457 ,  6.842395  ,  2.6778774 ]], dtype=float32)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_pred"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_zs37",
   "language": "python",
   "name": "zs37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
